/* Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 * ===================================================================================================================*/

#include "graph/symbolizer/symbolic.h"
#include "graph/symbolizer/symbolic_utils.h"
#include "common/checker.h"
#include "attribute_group/attr_group_shape_env.h"

namespace ge {
/**
 * Renders an expression as a C++-style string (StrType::kStrCpp).
 * Returns "invalid expression" when Str() yields no string, so the result is always safe to pass to "%s".
 */
std::string SymbolicUtils::ToString(const Expression &e) {
  auto ret = e.Str(StrType::kStrCpp);
  return (ret != nullptr) ? ret.get() : "invalid expression";
}

// The StaticCheck{Eq,Ne,Lt,Le,Gt,Ge} helpers below simplify both operands, build the matching relational
// expression through sym::{Eq,Ne,Lt,Le,Gt,Ge}, and delegate the decision to StaticCheckBool().

/// Statically checks whether e1 == e2.
TriBool SymbolicUtils::StaticCheckEq(const Expression &e1, const Expression &e2) {
  return StaticCheckBool(sym::Eq(e1.Simplify(), e2.Simplify()));
}

/// Statically checks whether e1 != e2.
TriBool SymbolicUtils::StaticCheckNe(const Expression &e1, const Expression &e2) {
  return StaticCheckBool(sym::Ne(e1.Simplify(), e2.Simplify()));
}

/// Statically checks whether e1 < e2.
TriBool SymbolicUtils::StaticCheckLt(const Expression &e1, const Expression &e2) {
  return StaticCheckBool(sym::Lt(e1.Simplify(), e2.Simplify()));
}

/// Statically checks whether e1 <= e2.
TriBool SymbolicUtils::StaticCheckLe(const Expression &e1, const Expression &e2) {
  return StaticCheckBool(sym::Le(e1.Simplify(), e2.Simplify()));
}

/// Statically checks whether e1 > e2.
TriBool SymbolicUtils::StaticCheckGt(const Expression &e1, const Expression &e2) {
  return StaticCheckBool(sym::Gt(e1.Simplify(), e2.Simplify()));
}

/// Statically checks whether e1 >= e2.
TriBool SymbolicUtils::StaticCheckGe(const Expression &e1, const Expression &e2) {
  return StaticCheckBool(sym::Ge(e1.Simplify(), e2.Simplify()));
}

/**
 * Tries to decide a boolean expression statically.
 *
 * Resolution order:
 *   1. A constant expression is evaluated directly (kTrue / kFalse).
 *   2. With no current shape-env context, the answer is kUnknown (a warning is logged).
 *   3. If the shape env reports the expression itself as known (HasSymbolInfo == kTrue), the answer is kTrue.
 *   4. Otherwise the expression is simplified. A constant result is evaluated directly; otherwise the
 *      shape env is queried with the simplified form and its answer is returned as-is.
 *
 * Fails through GE_ASSERT_TRUE if `expr` is not a boolean expression, or if a constant value cannot be read.
 * All diagnostics use SymbolicUtils::ToString(), which never yields a null C string, so no null pointer
 * is ever handed to a "%s" format specifier.
 */
TriBool SymbolicUtils::StaticCheckBool(const Expression &expr) {
  GE_ASSERT_TRUE(expr.IsBooleanExpr(), "Only boolean expr can do static check, expr: %s",
      SymbolicUtils::ToString(expr).c_str());
  if (expr.IsConstExpr()) {
    bool value = false;
    GE_ASSERT_TRUE(expr.GetConstValue(value));
    return value ? TriBool::kTrue : TriBool::kFalse;
  }
  // Look the context up once; it is consulted up to twice below.
  const auto shape_env = GetCurShapeEnvContext();
  if (shape_env == nullptr) {
    GELOGW("Shape env is nullptr, cannot do static check, expr: %s", SymbolicUtils::ToString(expr).c_str());
    return TriBool::kUnknown;
  }
  // NOTE(review): only a kTrue answer short-circuits here; kFalse/kUnknown fall through to the
  // simplified-expression check below.
  if (shape_env->HasSymbolInfo(expr) == TriBool::kTrue) {
    GELOGI("Find check info of expr: %s, no need simplify guard", SymbolicUtils::ToString(expr).c_str());
    return TriBool::kTrue;
  }
  const auto simplify_expr = expr.Simplify();
  // After simplification, check whether the expression has become a constant.
  if (simplify_expr.IsConstExpr()) {
    bool simplified_value = false;
    GE_ASSERT_TRUE(simplify_expr.GetConstValue(simplified_value));
    return simplified_value ? TriBool::kTrue : TriBool::kFalse;
  }
  return shape_env->HasSymbolInfo(simplify_expr);
}
}  // namespace ge

